{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YVfdLsnHFrZ6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "mSVgbpixegCb"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FZ2T25awqPPX"
      },
      "source": [
        "# サブクラス化によるレイヤとモデルの新規作成"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Q8vYTLTqJ1i"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/custom_layers_and_models\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"> TensorFlow.orgで表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/custom_layers_and_models.ipynb\">     <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colabで実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/guide/keras/custom_layers_and_models.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示{</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/guide/keras/custom_layers_and_models.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード/a0}</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kKyceCkqt8kO"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lhsGftkq8III"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "from tensorflow import keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ox9WJETD27aT"
      },
      "source": [
        "## `Layer` クラス：状態（重み）といくつかの計算の組み合わせ\n",
        "\n",
        "Kerasの中心的な抽象概念の1つは、`Layer`クラスです。 レイヤーは、状態（レイヤーの「重み」）と入力から出力への変換（「呼び出し」、レイヤーのフォワードパス）をカプセル化します。\n",
        "\n",
        "以下の密に接続されたレイヤーで、状態は変数`w`と`b`をもちます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "psLn5u4RUhru"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32, input_dim=32):\n",
        "        super(Linear, self).__init__()\n",
        "        w_init = tf.random_normal_initializer()\n",
        "        self.w = tf.Variable(\n",
        "            initial_value=w_init(shape=(input_dim, units), dtype=\"float32\"),\n",
        "            trainable=True,\n",
        "        )\n",
        "        b_init = tf.zeros_initializer()\n",
        "        self.b = tf.Variable(\n",
        "            initial_value=b_init(shape=(units,), dtype=\"float32\"), trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xa4ZKbcgY2ws"
      },
      "source": [
        "Python関数のように、レイヤーをテンソル入力で呼び出して使用します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wCFLPZvTII3w"
      },
      "outputs": [],
      "source": [
        "x = tf.ones((2, 2))\n",
        "linear_layer = Linear(4, 2)\n",
        "y = linear_layer(x)\n",
        "print(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VU6qtjLJ7OoU"
      },
      "source": [
        "重み`w`および`b`をレイヤー属性として設定すると、レイヤーによって自動的に追跡されるので注意してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jWeZhh3V7wJa"
      },
      "outputs": [],
      "source": [
        "assert linear_layer.weights == [linear_layer.w, linear_layer.b]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u8UFOXQ7wND7"
      },
      "source": [
        "また、`add_weight()`を使用すると、レイヤーにより迅速に重みを追加することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JrVpdNTTIVB7"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32, input_dim=32):\n",
        "        super(Linear, self).__init__()\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_dim, units), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "        self.b = self.add_weight(shape=(units,), initializer=\"zeros\", trainable=True)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "\n",
        "x = tf.ones((2, 2))\n",
        "linear_layer = Linear(4, 2)\n",
        "y = linear_layer(x)\n",
        "print(y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0FWNOtdYhuJy"
      },
      "source": [
        "## トレーニングできない重みをレイヤーに追加\n",
        "\n",
        "トレーニング可能な重みだけでなく、トレーニング不可能な重みもレイヤーに追加できます。このような重みは、レイヤーをトレーニングする際、バックプロパゲーションで考慮されません。\n",
        "\n",
        "以下のようにトレーニング不可能な重みを追加して使用します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6yTiS9tI2G8c"
      },
      "outputs": [],
      "source": [
        "class ComputeSum(keras.layers.Layer):\n",
        "    def __init__(self, input_dim):\n",
        "        super(ComputeSum, self).__init__()\n",
        "        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        self.total.assign_add(tf.reduce_sum(inputs, axis=0))\n",
        "        return self.total\n",
        "\n",
        "\n",
        "x = tf.ones((2, 2))\n",
        "my_sum = ComputeSum(2)\n",
        "y = my_sum(x)\n",
        "print(y.numpy())\n",
        "y = my_sum(x)\n",
        "print(y.numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I2jZ13RpWSJZ"
      },
      "source": [
        "これは`layer.weights`の一部ですが、トレーニング不可能な重みとして分類されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zHmQSpo4D4QQ"
      },
      "outputs": [],
      "source": [
        "print(\"weights:\", len(my_sum.weights))\n",
        "print(\"non-trainable weights:\", len(my_sum.non_trainable_weights))\n",
        "\n",
        "# It's not included in the trainable weights:\n",
        "print(\"trainable_weights:\", my_sum.trainable_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wjFeyyncwHsn"
      },
      "source": [
        "## ベストプラクティス：入力の形状がわかるまで重みづけを延期する\n",
        "\n",
        "上記の`Linear` レイヤーは、`input_dim `引数を受け取り、この引数は`__init__()`の重み`w`および`b` の形状を計算するために使用されました。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N9B8k9LtrJbl"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32, input_dim=32):\n",
        "        super(Linear, self).__init__()\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_dim, units), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "        self.b = self.add_weight(shape=(units,), initializer=\"zeros\", trainable=True)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aWx8tcSpyO7n"
      },
      "source": [
        "多くの場合、入力値は事前に分からないので、その値が分かった時に後から（レイヤーをインスタンス化した後など）重みを定義したいと思うこともあるでしょう。\n",
        "\n",
        "Keras APIでは、以下のようにレイヤーの`build(self, inputs_shape)`メソッドでレイヤーの重みを定義することをお勧めします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Cow04aRb83DJ"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32):\n",
        "        super(Linear, self).__init__()\n",
        "        self.units = units\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_shape[-1], self.units),\n",
        "            initializer=\"random_normal\",\n",
        "            trainable=True,\n",
        "        )\n",
        "        self.b = self.add_weight(\n",
        "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ButahocyBR7y"
      },
      "source": [
        "レイヤーの`__call__()`メソッドは、最初に呼び出されたときに自動的にビルドを実行します。後から重みを定義できる使いやすいレイヤーができました。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8E8ufQGU2r8u"
      },
      "outputs": [],
      "source": [
        "# At instantiation, we don't know on what inputs this is going to get called\n",
        "linear_layer = Linear(32)\n",
        "\n",
        "# The layer's weights are created dynamically the first time the layer is called\n",
        "y = linear_layer(x)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YYE0fkm8bIg1"
      },
      "source": [
        "## レイヤーは再帰的に構成可能\n",
        "\n",
        "レイヤーインスタンスを別のレイヤーの属性として割り当てると、外部レイヤーは内部レイヤーの重みを追跡し始めます。\n",
        "\n",
        "そのようなサブレイヤーは、`__init__()`メソッドで作成することをお勧めします（通常、サブレイヤーにはビルドメソッドがあるため、外部レイヤーがビルドされるときにサブレイヤーがビルドされます）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hJr89U0qOJqR"
      },
      "outputs": [],
      "source": [
        "# Let's assume we are reusing the Linear class\n",
        "# with a `build` method that we defined above.\n",
        "\n",
        "\n",
        "class MLPBlock(keras.layers.Layer):\n",
        "    def __init__(self):\n",
        "        super(MLPBlock, self).__init__()\n",
        "        self.linear_1 = Linear(32)\n",
        "        self.linear_2 = Linear(32)\n",
        "        self.linear_3 = Linear(1)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        x = self.linear_1(inputs)\n",
        "        x = tf.nn.relu(x)\n",
        "        x = self.linear_2(x)\n",
        "        x = tf.nn.relu(x)\n",
        "        return self.linear_3(x)\n",
        "\n",
        "\n",
        "mlp = MLPBlock()\n",
        "y = mlp(tf.ones(shape=(3, 64)))  # The first call to the `mlp` will create the weights\n",
        "print(\"weights:\", len(mlp.weights))\n",
        "print(\"trainable weights:\", len(mlp.trainable_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wQUEr9dZqxaN"
      },
      "source": [
        "## `add_loss()`メソッド\n",
        "\n",
        "レイヤーの`call()`メソッドを記述する際、トレーニングループを作成するときに、後で使用する損失テンソルを作成できます。これは、`self.add_loss(value)`を呼び出すことで実行できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W99EJ41lQhOR"
      },
      "outputs": [],
      "source": [
        "# A layer that creates an activity regularization loss\n",
        "class ActivityRegularizationLayer(keras.layers.Layer):\n",
        "    def __init__(self, rate=1e-2):\n",
        "        super(ActivityRegularizationLayer, self).__init__()\n",
        "        self.rate = rate\n",
        "\n",
        "    def call(self, inputs):\n",
        "        self.add_loss(self.rate * tf.reduce_sum(inputs))\n",
        "        return inputs\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gjB17DfePyIm"
      },
      "source": [
        "これらの損失（内部レイヤーによって作成されたものを含む）は、 `layer.losses`で取得できます。このプロパティは、すべての`__call__()`の開始時にトップレベルレイヤーにリセットされるので、 `layer.losses` には、最後のフォワードパス時に定義された損失値が常に含まれます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "48QiYwJRIe96"
      },
      "outputs": [],
      "source": [
        "class OuterLayer(keras.layers.Layer):\n",
        "    def __init__(self):\n",
        "        super(OuterLayer, self).__init__()\n",
        "        self.activity_reg = ActivityRegularizationLayer(1e-2)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return self.activity_reg(inputs)\n",
        "\n",
        "\n",
        "layer = OuterLayer()\n",
        "assert len(layer.losses) == 0  # No losses yet since the layer has never been called\n",
        "\n",
        "_ = layer(tf.zeros(1, 1))\n",
        "assert len(layer.losses) == 1  # We created one loss value\n",
        "\n",
        "# `layer.losses` gets reset at the start of each __call__\n",
        "_ = layer(tf.zeros(1, 1))\n",
        "assert len(layer.losses) == 1  # This is the loss created during the call above"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "72VW9A52XfTT"
      },
      "source": [
        "さらに、`loss`プロパティには、内部レイヤーの重みに対して定義された正則化損失も含まれます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OBrP81sfVWUn"
      },
      "outputs": [],
      "source": [
        "class OuterLayerWithKernelRegularizer(keras.layers.Layer):\n",
        "    def __init__(self):\n",
        "        super(OuterLayerWithKernelRegularizer, self).__init__()\n",
        "        self.dense = keras.layers.Dense(\n",
        "            32, kernel_regularizer=tf.keras.regularizers.l2(1e-3)\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return self.dense(inputs)\n",
        "\n",
        "\n",
        "layer = OuterLayerWithKernelRegularizer()\n",
        "_ = layer(tf.zeros((1, 1)))\n",
        "\n",
        "# This is `1e-3 * sum(layer.dense.kernel ** 2)`,\n",
        "# created by the `kernel_regularizer` above.\n",
        "print(layer.losses)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7P88QF0urVzA"
      },
      "source": [
        "これらの損失は、以下のようにトレーニングループを記述するときに考慮されることを意図しています。\n",
        "\n",
        "```python\n",
        "# Instantiate an optimizer.\n",
        "optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)\n",
        "loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "# Iterate over the batches of a dataset.\n",
        "for x_batch_train, y_batch_train in train_dataset:\n",
        "  with tf.GradientTape() as tape:\n",
        "    logits = layer(x_batch_train)  # Logits for this minibatch\n",
        "    # Loss value for this minibatch\n",
        "    loss_value = loss_fn(y_batch_train, logits)\n",
        "    # Add extra losses created during this forward pass:\n",
        "    loss_value += sum(model.losses)\n",
        "\n",
        "  grads = tape.gradient(loss_value, model.trainable_weights)\n",
        "  optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Wo9mCyvIheY"
      },
      "source": [
        "トレーニングループの記述に関する詳細は、[「トレーニングループを新規作成するためのガイド」](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch/)を参照してください。\n",
        "\n",
        "これらの損失は、`fit()` とシームレスに連携します（自動的に集計され、メジャーな損失がある場合にはそれに追加されます）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "prhlyHQU8eqT"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "inputs = keras.Input(shape=(3,))\n",
        "outputs = ActivityRegularizationLayer()(inputs)\n",
        "model = keras.Model(inputs, outputs)\n",
        "\n",
        "# If there is a loss passed in `compile`, thee regularization\n",
        "# losses get added to it\n",
        "model.compile(optimizer=\"adam\", loss=\"mse\")\n",
        "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n",
        "\n",
        "# It's also possible not to pass any loss in `compile`,\n",
        "# since the model already has a loss to minimize, via the `add_loss`\n",
        "# call during the forward pass!\n",
        "model.compile(optimizer=\"adam\")\n",
        "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5Fw96x84S8HS"
      },
      "source": [
        "## `add_metric()`メソッド\n",
        "\n",
        "`add_loss()`と同様に、レイヤーでは`add_metric()`メソッドでトレーニング中の数量の移動平均を追跡できます。\n",
        "\n",
        "「ロジスティックエンドポイント」レイヤーでは、入力として予測とターゲットを受け取り、 `add_loss()`を介して追跡する損失を計算し、`add_metric()`を介して追跡する精度スカラーを計算します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kjhgiCfke2Dx"
      },
      "outputs": [],
      "source": [
        "class LogisticEndpoint(keras.layers.Layer):\n",
        "    def __init__(self, name=None):\n",
        "        super(LogisticEndpoint, self).__init__(name=name)\n",
        "        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
        "        self.accuracy_fn = keras.metrics.BinaryAccuracy()\n",
        "\n",
        "    def call(self, targets, logits, sample_weights=None):\n",
        "        # Compute the training-time loss value and add it\n",
        "        # to the layer using `self.add_loss()`.\n",
        "        loss = self.loss_fn(targets, logits, sample_weights)\n",
        "        self.add_loss(loss)\n",
        "\n",
        "        # Log accuracy as a metric and add it\n",
        "        # to the layer using `self.add_metric()`.\n",
        "        acc = self.accuracy_fn(targets, logits, sample_weights)\n",
        "        self.add_metric(acc, name=\"accuracy\")\n",
        "\n",
        "        # Return the inference-time prediction tensor (for `.predict()`).\n",
        "        return tf.nn.softmax(logits)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eEcShLRgzCcp"
      },
      "source": [
        "この方法で追跡されるメトリクスは、` layer.metrics `で取得できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vK9Y6ierSyqd"
      },
      "outputs": [],
      "source": [
        "layer = LogisticEndpoint()\n",
        "\n",
        "targets = tf.ones((2, 2))\n",
        "logits = tf.ones((2, 2))\n",
        "y = layer(targets, logits)\n",
        "\n",
        "print(\"layer.metrics:\", layer.metrics)\n",
        "print(\"current accuracy value:\", float(layer.metrics[0].result()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ogfTmUPmHqPS"
      },
      "source": [
        "`add_loss()`の場合と同様に、これらのメトリックは`fit()`により追跡されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YEzldbjt7oF8"
      },
      "outputs": [],
      "source": [
        "inputs = keras.Input(shape=(3,), name=\"inputs\")\n",
        "targets = keras.Input(shape=(10,), name=\"targets\")\n",
        "logits = keras.layers.Dense(10)(inputs)\n",
        "predictions = LogisticEndpoint(name=\"predictions\")(logits, targets)\n",
        "\n",
        "model = keras.Model(inputs=[inputs, targets], outputs=predictions)\n",
        "model.compile(optimizer=\"adam\")\n",
        "\n",
        "data = {\n",
        "    \"inputs\": np.random.random((3, 3)),\n",
        "    \"targets\": np.random.random((3, 10)),\n",
        "}\n",
        "model.fit(data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jvZNZlgFFPGW"
      },
      "source": [
        "## レイヤーのシリアル化をオプションで有効化\n",
        "\n",
        "カスタムレイヤーを[Functionalモデル](https://www.tensorflow.org/guide/keras/functional/)の一部としてシリアル化可能にする必要がある場合は、オプションで`get_config()` メソッドを実装できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3EWDINbqDVRZ"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32):\n",
        "        super(Linear, self).__init__()\n",
        "        self.units = units\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_shape[-1], self.units),\n",
        "            initializer=\"random_normal\",\n",
        "            trainable=True,\n",
        "        )\n",
        "        self.b = self.add_weight(\n",
        "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "    def get_config(self):\n",
        "        return {\"units\": self.units}\n",
        "\n",
        "\n",
        "# Now you can recreate the layer from its config:\n",
        "layer = Linear(64)\n",
        "config = layer.get_config()\n",
        "print(config)\n",
        "new_layer = Linear.from_config(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RBvNVeGPxhRE"
      },
      "source": [
        "基底 `Layer`クラスの`__init__()`メソッドでは、`name`や`dtype`などのキーワード引数を使えます。これらの引数を `__init__()` で親クラスに渡し、レイヤーコンフィギュレーションに含めることをお勧めします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fRdhwwqKTZxI"
      },
      "outputs": [],
      "source": [
        "class Linear(keras.layers.Layer):\n",
        "    def __init__(self, units=32, **kwargs):\n",
        "        super(Linear, self).__init__(**kwargs)\n",
        "        self.units = units\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(\n",
        "            shape=(input_shape[-1], self.units),\n",
        "            initializer=\"random_normal\",\n",
        "            trainable=True,\n",
        "        )\n",
        "        self.b = self.add_weight(\n",
        "            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
        "        )\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return tf.matmul(inputs, self.w) + self.b\n",
        "\n",
        "    def get_config(self):\n",
        "        config = super(Linear, self).get_config()\n",
        "        config.update({\"units\": self.units})\n",
        "        return config\n",
        "\n",
        "\n",
        "layer = Linear(64)\n",
        "config = layer.get_config()\n",
        "print(config)\n",
        "new_layer = Linear.from_config(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cqHlieM5gqgE"
      },
      "source": [
        "構成からレイヤーを逆シリアル化するときに柔軟性が必要な場合は、 `from_config()` クラスメソッドをオーバーライドすることもできます。これは、`from_config()`の基本実装です。\n",
        "\n",
        "```python\n",
        "def from_config(cls, config):   return cls(**config)\n",
        "```\n",
        "\n",
        "シリアル化と保存の詳細については、[「モデルの保存とシリアル化のガイド」](https://www.tensorflow.org/guide/keras/save_and_serialize/)を参照してください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qCXh57DUwJd8"
      },
      "source": [
        "## `call()` メソッドの特権付き `トレーニング`引数\n",
        "\n",
        "一部のレイヤー、特に`BatchNormalization`レイヤーと`Dropout`レイヤーでは、トレーニングと推論での動作が異なります。このようなレイヤーでは、`call()` メソッドで`トレーニング`引数（ブール値）を公開するのが標準的なメソッドです。\n",
        "\n",
        "この引数を `call()`で公開することで、組み込みのトレーニングループと評価ループ（`fit()`など）を有効にして、トレーニングと推論でレイヤーを正しく使用できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BP35dXiYoTcD"
      },
      "outputs": [],
      "source": [
        "class CustomDropout(keras.layers.Layer):\n",
        "    def __init__(self, rate, **kwargs):\n",
        "        super(CustomDropout, self).__init__(**kwargs)\n",
        "        self.rate = rate\n",
        "\n",
        "    def call(self, inputs, training=None):\n",
        "        if training:\n",
        "            return tf.nn.dropout(inputs, rate=self.rate)\n",
        "        return inputs\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OudWKhz1iQXP"
      },
      "source": [
        "## `call()` メソッドの特権付き `マスク`引数\n",
        "\n",
        "`call()`がサポートするもう１つの特権引数は、`マスク`引数です。\n",
        "\n",
        "これはすべてのKeras RNNレイヤーにあります。 マスクは、時系列データを処理するときに特定の入力タイムステップをスキップするために使用されるブールテンソルです。（入力するタイムステップごとに1つのブール値を返します）。\n",
        "\n",
        "前のレイヤーによってマスクが生成されると、Kerasは、それをサポートするレイヤーの正しい`マスク` 引数を自動的に `__call__()`に渡します。 マスク生成レイヤーは、`mask_zero=True`で構成された`Embedding`レイヤーと、`Masking` レイヤーです。\n",
        "\n",
        "マスキングの詳細と、マスキングが有効なレイヤーの記述については、 [「マスキングとパディングのガイド」](https://www.tensorflow.org/guide/keras/masking_and_padding/)を参照してください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6n3ZYo8abHGd"
      },
      "source": [
        "## `モデル`クラス\n",
        "\n",
        "一般的に、`レイヤー`クラスを使用して内部計算ブロックを定義し、`モデル`クラスを使用して外部モデル（トレーニングするオブジェクト）を定義します。\n",
        "\n",
        "たとえば、ResNet50モデルでは、`レイヤー`をサブクラス化する複数のResNetブロックと、ResNet50ネットワーク全体を包含する１つの `モデル`があります。\n",
        "\n",
        "`モデル`クラスは`レイヤー`クラスと同じAPIをもちますが、以下の点で異なります。\n",
        "\n",
        "- 組み込みのトレーニング、評価、予測ループを公開します(`model.fit()`、`model.evaluate()`、`model.predict()`)。\n",
        "- `model.layers`プロパティを介して、その内部レイヤーのリストを公開します。\n",
        "- 保存およびシリアル化APIを公開します(`save()`、`save_weights()`...)\n",
        "\n",
        "実質的に`レイヤー`クラスは、文献で「レイヤー」（「畳み込みレイヤー」または「リカレントレイヤー」など）または「ブロック」（「ResNetブロック」または「Inceptionブロック」など）と呼ばれているものに対応します。\n",
        "\n",
        "一方、`モデル`クラスは、文献で「モデル」（「ディープラーニングモデル」など）または「ネットワーク」（「ディープニューラルネットワーク」など）と呼ばれているものに対応します 。\n",
        "\n",
        "`レイヤー`クラスまたは`モデル`クラスのどちらを使用すべきか迷っている場合は、次の点を確認してください。 `fit()`や`save()`を呼び出す必要がある場合は、`モデル`を使用してください。 その他の場合（クラスがより大きなシステムの単なるブロックである場合や自分でトレーニングを記述してコードを保存する場合）は、`レイヤー`を使用します。\n",
        "\n",
        "たとえば、上のmini-resnetの例を使用して`モデル` を構築し、`fit()`でトレーニングし、`save_weights()`で保存できます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FNGSP4aBm6fv"
      },
      "source": [
        "```python\n",
        "class ResNet(tf.keras.Model):\n",
        "\n",
        "    def __init__(self):\n",
        "        super(ResNet, self).__init__()\n",
        "        self.block_1 = ResNetBlock()\n",
        "        self.block_2 = ResNetBlock()\n",
        "        self.global_pool = layers.GlobalAveragePooling2D()\n",
        "        self.classifier = Dense(num_classes)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        x = self.block_1(inputs)\n",
        "        x = self.block_2(x)\n",
        "        x = self.global_pool(x)\n",
        "        return self.classifier(x)\n",
        "\n",
        "\n",
        "resnet = ResNet()\n",
        "dataset = ...\n",
        "resnet.fit(dataset, epochs=10)\n",
        "resnet.save(filepath)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cdhz04jvl5pu"
      },
      "source": [
        "## まとめ：エンドツーエンドの例\n",
        "\n",
        "今まで学んだことを復習しましょう。\n",
        "\n",
        "- `レイヤー`は、状態（`__init__()`または `build()`で作成）といくつかの計算（`call()`で定義）をカプセル化します。\n",
        "- レイヤーを再帰的にネストし、新しい大きな計算ブロックを作成できます。\n",
        "- レイヤーは、`add_loss()`および `add_metric()`を介して、損失（通常は正則化損失）およびメトリックを作成および追跡できます。\n",
        "- トレーニングの対象となる外部コンテナは、`モデル`です。`モデル`は`レイヤー`に似ていますが、トレーニングとシリアル化ユーティリティが追加されます。\n",
        "\n",
        "これらすべてをエンドツーエンドの例にまとめましょう。Variational AutoEncoder (VAE)を実装し、 MNISTディジットでトレーニングします。\n",
        "\n",
        "VAEは`モデル`のサブクラスで、` Layer `をサブクラス化するLayerのネストされた構成として構築されます。 正則化損失（KLダイバージェンス）が特徴です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MkWGvP01ZsIu"
      },
      "outputs": [],
      "source": [
        "from tensorflow.keras import layers\n",
        "\n",
        "\n",
        "class Sampling(layers.Layer):\n",
        "    \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n",
        "\n",
        "    def call(self, inputs):\n",
        "        z_mean, z_log_var = inputs\n",
        "        batch = tf.shape(z_mean)[0]\n",
        "        dim = tf.shape(z_mean)[1]\n",
        "        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))\n",
        "        return z_mean + tf.exp(0.5 * z_log_var) * epsilon\n",
        "\n",
        "\n",
        "class Encoder(layers.Layer):\n",
        "    \"\"\"Maps MNIST digits to a triplet (z_mean, z_log_var, z).\"\"\"\n",
        "\n",
        "    def __init__(self, latent_dim=32, intermediate_dim=64, name=\"encoder\", **kwargs):\n",
        "        super(Encoder, self).__init__(name=name, **kwargs)\n",
        "        self.dense_proj = layers.Dense(intermediate_dim, activation=\"relu\")\n",
        "        self.dense_mean = layers.Dense(latent_dim)\n",
        "        self.dense_log_var = layers.Dense(latent_dim)\n",
        "        self.sampling = Sampling()\n",
        "\n",
        "    def call(self, inputs):\n",
        "        x = self.dense_proj(inputs)\n",
        "        z_mean = self.dense_mean(x)\n",
        "        z_log_var = self.dense_log_var(x)\n",
        "        z = self.sampling((z_mean, z_log_var))\n",
        "        return z_mean, z_log_var, z\n",
        "\n",
        "\n",
        "class Decoder(layers.Layer):\n",
        "    \"\"\"Converts z, the encoded digit vector, back into a readable digit.\"\"\"\n",
        "\n",
        "    def __init__(self, original_dim, intermediate_dim=64, name=\"decoder\", **kwargs):\n",
        "        super(Decoder, self).__init__(name=name, **kwargs)\n",
        "        self.dense_proj = layers.Dense(intermediate_dim, activation=\"relu\")\n",
        "        self.dense_output = layers.Dense(original_dim, activation=\"sigmoid\")\n",
        "\n",
        "    def call(self, inputs):\n",
        "        x = self.dense_proj(inputs)\n",
        "        return self.dense_output(x)\n",
        "\n",
        "\n",
        "class VariationalAutoEncoder(keras.Model):\n",
        "    \"\"\"Combines the encoder and decoder into an end-to-end model for training.\"\"\"\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        original_dim,\n",
        "        intermediate_dim=64,\n",
        "        latent_dim=32,\n",
        "        name=\"autoencoder\",\n",
        "        **kwargs\n",
        "    ):\n",
        "        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)\n",
        "        self.original_dim = original_dim\n",
        "        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)\n",
        "        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        z_mean, z_log_var, z = self.encoder(inputs)\n",
        "        reconstructed = self.decoder(z)\n",
        "        # Add KL divergence regularization loss.\n",
        "        kl_loss = -0.5 * tf.reduce_mean(\n",
        "            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1\n",
        "        )\n",
        "        self.add_loss(kl_loss)\n",
        "        return reconstructed\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uBAILhN56qaB"
      },
      "source": [
        "MNISTで簡単なトレーニングループを書いてみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3WDjPVGOvvMR"
      },
      "outputs": [],
      "source": [
        "original_dim = 784\n",
        "vae = VariationalAutoEncoder(original_dim, 64, 32)\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
        "mse_loss_fn = tf.keras.losses.MeanSquaredError()\n",
        "\n",
        "loss_metric = tf.keras.metrics.Mean()\n",
        "\n",
        "(x_train, _), _ = tf.keras.datasets.mnist.load_data()\n",
        "x_train = x_train.reshape(60000, 784).astype(\"float32\") / 255\n",
        "\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices(x_train)\n",
        "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n",
        "\n",
        "epochs = 2\n",
        "\n",
        "# Iterate over epochs.\n",
        "for epoch in range(epochs):\n",
        "    print(\"Start of epoch %d\" % (epoch,))\n",
        "\n",
        "    # Iterate over the batches of the dataset.\n",
        "    for step, x_batch_train in enumerate(train_dataset):\n",
        "        with tf.GradientTape() as tape:\n",
        "            reconstructed = vae(x_batch_train)\n",
        "            # Compute reconstruction loss\n",
        "            loss = mse_loss_fn(x_batch_train, reconstructed)\n",
        "            loss += sum(vae.losses)  # Add KLD regularization loss\n",
        "\n",
        "        grads = tape.gradient(loss, vae.trainable_weights)\n",
        "        optimizer.apply_gradients(zip(grads, vae.trainable_weights))\n",
        "\n",
        "        loss_metric(loss)\n",
        "\n",
        "        if step % 100 == 0:\n",
        "            print(\"step %d: mean loss = %.4f\" % (step, loss_metric.result()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yUWIKs8Nx70G"
      },
      "source": [
        "VAEは`Model`をサブクラス化しているため、組み込みのトレーニングループが備わっています。 そのため、次のようにトレーニングすることもできます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cuadnNIqW0Jb"
      },
      "outputs": [],
      "source": [
        "vae = VariationalAutoEncoder(784, 64, 32)\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
        "\n",
        "vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())\n",
        "vae.fit(x_train, x_train, epochs=2, batch_size=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RiujG1u0qLp0"
      },
      "source": [
        "## オブジェクト指向開発を超えて：Functional API\n",
        "\n",
        "ここでは、オブジェクト指向の開発の例を多く紹介しましたが、 [Functional API](https://www.tensorflow.org/guide/keras/functional/)を使用してモデルを構築することもできます。重要なのは、いずれかのスタイルを選択しても、もう一方のスタイルで記述されたコンポーネントを組み合わせることができるということです。\n",
        "\n",
        "たとえば、以下のFunctional APIの例では、上の例で定義した `Sampling` レイヤーを再利用しています。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "el7h9ajsqCR9"
      },
      "outputs": [],
      "source": [
        "original_dim = 784\n",
        "intermediate_dim = 64\n",
        "latent_dim = 32\n",
        "\n",
        "# Define encoder model.\n",
        "original_inputs = tf.keras.Input(shape=(original_dim,), name=\"encoder_input\")\n",
        "x = layers.Dense(intermediate_dim, activation=\"relu\")(original_inputs)\n",
        "z_mean = layers.Dense(latent_dim, name=\"z_mean\")(x)\n",
        "z_log_var = layers.Dense(latent_dim, name=\"z_log_var\")(x)\n",
        "z = Sampling()((z_mean, z_log_var))\n",
        "encoder = tf.keras.Model(inputs=original_inputs, outputs=z, name=\"encoder\")\n",
        "\n",
        "# Define decoder model.\n",
        "latent_inputs = tf.keras.Input(shape=(latent_dim,), name=\"z_sampling\")\n",
        "x = layers.Dense(intermediate_dim, activation=\"relu\")(latent_inputs)\n",
        "outputs = layers.Dense(original_dim, activation=\"sigmoid\")(x)\n",
        "decoder = tf.keras.Model(inputs=latent_inputs, outputs=outputs, name=\"decoder\")\n",
        "\n",
        "# Define VAE model.\n",
        "outputs = decoder(z)\n",
        "vae = tf.keras.Model(inputs=original_inputs, outputs=outputs, name=\"vae\")\n",
        "\n",
        "# Add KL divergence regularization loss.\n",
        "kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)\n",
        "vae.add_loss(kl_loss)\n",
        "\n",
        "# Train.\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
        "vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())\n",
        "vae.fit(x_train, x_train, epochs=3, batch_size=64)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VsMiP9ABJHWF"
      },
      "source": [
        "詳細につきましては、[Functional API のガイド](https://www.tensorflow.org/guide/keras/functional/)を参照してください。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_layers_and_models.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
